import matplotlib.pyplot as plt
import torch
import numpy as np
from parameters import Parameters
# Module-level configuration instance; ntensorshow reads params.tensorboard_dir
# from it to build the output path when saving figures.
params = Parameters()
from utils.cmplxBatchNorm import magnitude


def combine_coils_RSOS(data):
    """Collapse dimension 1 of ``data`` by root-sum-of-squares of magnitudes.

    The magnitude of ``data`` is squared, summed along dimension 1 and
    square-rooted. Dimension 1 is kept with size 1 (``keepdim=True``).

    Args:
        data: Input passed to ``magnitude``; presumably a multi-coil image
            tensor with the coil axis at dimension 1 -- confirm with callers.

    Returns:
        The combined tensor, with dimension 1 reduced to size 1.
    """
    squared_magnitude = magnitude(data) ** 2
    summed = torch.sum(squared_magnitude, dim=1, keepdim=True)
    return torch.sqrt(summed)

def tensorshow(x, sl_dims=(0, 0), rng=(0, 0.001)):
    """Display a single 2-D magnitude slice of ``x`` in a new grayscale figure.

    Three input kinds are handled:

    * complex NumPy array: indexed as ``x[a, b, :, :]`` and shown as ``|x|``;
    * real NumPy array: indexed as ``x[a, b, :, :, c]``; entries 0 and 1 of
      the last axis are combined as ``sqrt(c0**2 + c1**2)`` (presumably the
      real and imaginary parts -- confirm with callers);
    * anything else: passed to ``magnitude``, moved to the CPU, converted to
      NumPy and indexed as ``[a, b, :, :]`` (presumably a torch tensor).

    Args:
        x: Data to display, see above.
        sl_dims: Indices ``(a, b)`` selecting the entry along the first two
            axes.
        rng: ``(vmin, vmax)`` intensity window passed to ``imshow``.
    """
    first, second = sl_dims[0], sl_dims[1]

    # isinstance (rather than an exact type comparison) also routes ndarray
    # subclasses through the NumPy branch.
    if isinstance(x, np.ndarray):
        if np.iscomplexobj(x):
            # Slice before taking the modulus so only the shown plane is
            # processed.
            img = np.abs(x[first, second, :, :])
        else:
            img = np.sqrt(x[first, second, :, :, 0] ** 2 +
                          x[first, second, :, :, 1] ** 2)
    else:
        img = magnitude(x).cpu().data.numpy()[first, second, :, :]

    # Create the figure only once the image is ready, so a failure above does
    # not leave an empty figure behind.
    plt.figure()
    plt.imshow(img, cmap='gray', vmin=rng[0], vmax=rng[1])
    plt.show()

def ntensorshow(x, sl_dims=(0, 0), rng=(0, 0.001), titles=None, saveFigs=False, figname=None):
    """Show or save the volumes in ``x`` side by side, one figure per slice.

    For every index ``sl`` along dimension 0, a figure with one row and
    ``len(x)`` columns is created. Each column shows
    ``[sl, sl_dims[1], :, :]`` of the corresponding volume's magnitude image.
    Each panel is windowed individually to ``[mean - std, mean + 3 * std]``
    of its own pixel values.

    Args:
        x: Sequence of tensors. The number of slices is taken from
            ``x[0].shape[0]``. A volume whose dimension 1 is larger than 1 is
            reduced with ``combine_coils_RSOS``; otherwise ``magnitude`` is
            applied directly. (Presumably dimension 1 is the coil axis --
            confirm with callers.)
        sl_dims: Only ``sl_dims[1]`` is used, as the index along dimension 1
            of the magnitude image.
        rng: Unused; kept so existing callers keep working.
        titles: Optional sequence with one title per entry of ``x``.
        saveFigs: If true, each figure is written as a PNG and closed;
            otherwise it is shown.
        figname: Optional sequence with one file-name stem per slice.
            Defaults to zero-padded slice indices (``'000'``, ``'001'``, ...).
    """
    n_slices = x[0].shape[0]
    if figname is None:
        figname = [str(idx).zfill(3) for idx in range(n_slices)]

    # The magnitude images do not depend on the slice being drawn, so compute
    # them once instead of once per slice.
    magnitudes = [
        combine_coils_RSOS(vol) if vol.shape[1] > 1 else magnitude(vol)
        for vol in x
    ]

    for sl in range(n_slices):
        # squeeze=False always yields a 2-D array of axes; without it a single
        # input makes plt.subplots return a bare Axes that cannot be iterated.
        fig, axs = plt.subplots(1, len(x), squeeze=False)

        for col, (ax, mag) in enumerate(zip(axs[0], magnitudes)):
            if titles is not None:
                ax.set_title(titles[col])

            img = mag[sl, sl_dims[1], :, :]

            # Plain Python floats for the display window, so matplotlib is
            # not handed 0-dim tensors.
            img_mean = img.mean().item()
            img_std = img.std().item()

            ax.imshow(img.detach().cpu().numpy(), cmap='gray',
                      vmin=img_mean - img_std, vmax=img_mean + 3 * img_std)
            ax.axis('off')

        if saveFigs:
            # NOTE(review): plain string concatenation -- assumes
            # params.tensorboard_dir already ends with a path separator.
            fig.savefig(params.tensorboard_dir + figname[sl] + '.png', dpi=300)
            plt.close(fig)
        else:
            fig.show()
            fig.canvas.flush_events()
